Multi-Class Classification Using a scikit Decision Tree

I’ve been reviewing the scikit-learn (scikit for short) library for several months, so I figured I’d do a multi-class decision tree classification example. Before I go any further, let me comment that machine learning beginners are often seduced by the visual elegance of decision trees, but tree classifiers have several weaknesses.

I used one of my standard datasets for multi-class classification. The data looks like:

 1   0.24   1 0 0   0.2950   2
 0   0.39   0 0 1   0.5120   1
 1   0.63   0 1 0   0.7580   0
 0   0.36   1 0 0   0.4450   1
. . . 

Each line of data represents a person. The fields are sex (male = 0, female = 1), age (normalized by dividing by 100), State (Michigan = 100, Nebraska = 010, Oklahoma = 001), annual income (divided by 100,000), and politics type (conservative = 0, moderate = 1, liberal = 2). The goal is to predict the politics type of a person from their sex, age, State, and income.

Because a decision tree splits on just one predictor at a time, it isn’t necessary to normalize age and income, but normalized data can be reused with techniques, such as neural networks, that do require scaling. Converting categorical predictors like State is conceptually tricky, but the bottom line is that in most scenarios it’s best to one-hot encode. For binary predictor variables, I recommend 0-1 encoding, but again, there are a lot of subtle details.
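
The demo data was one-hot encoded manually in a text editor. For larger datasets the encoding can be done programmatically, for example with the scikit OneHotEncoder class. A minimal sketch, using hypothetical raw State strings that aren’t part of the demo data files:

  import numpy as np
  from sklearn.preprocessing import OneHotEncoder

  # hypothetical raw values -- the demo files are already encoded
  states = np.array([["michigan"], ["nebraska"],
    ["oklahoma"], ["nebraska"]])
  enc = OneHotEncoder(sparse=False)  # sparse_output=False in newer versions
  encoded = enc.fit_transform(states)  # categories sorted alphabetically
  print(encoded)  # rows like [1. 0. 0.], [0. 1. 0.], [0. 0. 1.]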

The key lines of code are:

  import numpy as np 
  from sklearn import tree 

  md = 3  # max depth
  print("Creating decision tree max_depth=" + str(md))
  model = tree.DecisionTreeClassifier(max_depth=md) 
  model.fit(train_x, train_y)
  print("Done ")

Decision trees are highly prone to overfitting. If you set a large max_depth you can often get 100% accuracy on the training data, but accuracy on test data and on new, previously unseen data will likely be very poor.
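
One way to pick a reasonable max_depth is to compare accuracy on the training and test data over a range of depths and watch for the gap that signals overfitting. A minimal sketch, assuming the train_x, train_y, test_x, test_y arrays have been loaded as in the demo program below:

  for md in [1, 2, 4, 6, 8, 10]:
    m = tree.DecisionTreeClassifier(max_depth=md)
    m.fit(train_x, train_y)
    print("depth = %2d  train acc = %0.4f  test acc = %0.4f" %
      (md, m.score(train_x, train_y), m.score(test_x, test_y)))

The min_samples_leaf and min_samples_split parameters are other built-in ways to limit tree growth.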

The accuracy of the model can be displayed using the built-in score() method, or examined in more detail in the form of a confusion matrix:

  from sklearn.metrics import confusion_matrix
  y_predicteds = model.predict(test_x)
  cm = confusion_matrix(test_y, y_predicteds)
  print("Confusion matrix: \n")
  # print(cm)  # no formatting
  show_confusion(cm)  # custom formatting
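
The raw counts in a confusion matrix can be supplemented by per-class precision and recall. A minimal sketch using the built-in classification_report() function, assuming the model and test data from the demo:

  from sklearn.metrics import classification_report
  y_predicteds = model.predict(test_x)
  print(classification_report(test_y, y_predicteds,
    target_names=["con", "mod", "lib"]))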

There are several ways to visualize a trained tree classifier. The model can be displayed as text pseudo-code like this:

  pseudo = tree.export_text(model,
    feature_names=["sex", "age",
    "state0", "state1", "state2",
    "income"])
  print("Model in pseudo-code: ")
  print(pseudo)

The wackiness of the pseudo-code points out another weakness of decision trees: they’re highly sensitive to small changes in the training data.
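
You can see this instability directly by retraining on a slightly perturbed version of the training data and comparing the pseudo-code of the two trees. A minimal sketch, assuming the train_x and train_y arrays from the demo, that randomly drops five training rows:

  rnd = np.random.RandomState(1)  # arbitrary seed
  keep = rnd.choice(len(train_x), size=len(train_x)-5,
    replace=False)  # indices with 5 random rows dropped
  model2 = tree.DecisionTreeClassifier(max_depth=3)
  model2.fit(train_x[keep], train_y[keep])
  pseudo2 = tree.export_text(model2,
    feature_names=["sex", "age",
    "state0", "state1", "state2", "income"])
  print(pseudo2)  # split structure often differs from the original tree

This sensitivity is one of the motivations for ensemble techniques such as random forests, which average many trees trained on resampled data.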

The tree model can be displayed graphically using the plot_tree() function like so:

  import matplotlib.pyplot as plt
  plt.figure(figsize=(14,8),
    tight_layout=True)  # w,h inches
  tree.plot_tree(model,
    feature_names=["sex", "age",
      "state0", "state1", "state2",
      "income"],
    class_names=["con", "mod", "lib"],
    fontsize=8)
  plt.show()
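
To save the diagram to a file instead of displaying it in a pop-up window, call plt.savefig() before plt.show(). The built-in export_graphviz() function is an alternative that writes a .dot file which the external Graphviz tool can render. A minimal sketch (the file names here are arbitrary):

  plt.savefig(".\\tree_graph.png", dpi=150)  # must come before plt.show()

  tree.export_graphviz(model, out_file=".\\tree_graph.dot",
    feature_names=["sex", "age",
      "state0", "state1", "state2", "income"],
    class_names=["con", "mod", "lib"],
    filled=True)  # render with: dot -Tpng tree_graph.dot -o tree_graph.png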

Anyway, the demo was a good refresher for me.



One of my favorite series of science fiction books is the Mars series by author Edgar Rice Burroughs. The fictional world has a lot of politics and races: Red Martians (human-like), Green Martians (fierce, 15 feet tall, with six limbs), Yellow Martians (secretive), White Martians (predecessors to the Red Martians), and Black Martians (evil).

“A Fighting Man of Mars” is the seventh book in the series. It was first published in book form in 1931. The book tells the story of low-born soldier Tan Hadron, who sets off to rescue snooty noblewoman Sanoma. He has many adventures and rescues, and falls in love with the beautiful slave Tavia, who turns out to be a princess.

Left: Cover art by Robert Abbett. Center: Cover art by Michael Whelan. Right: Cover art by Roy Krenkel.


Demo code. Replace “lte” with Boolean less-than-or-equal operator. The data is also listed below.

# people_politics_tree_scikit.py

# predict politics (0 = con, 1 = mod, 2 = lib) 
# from sex, age, state, income

# sex  age    state    income   politics
#  0   0.27   0  1  0   0.7610   2
#  1   0.19   0  0  1   0.6550   0
# sex: 0 = male, 1 = female
# state: michigan = 100, nebraska = 010, oklahoma = 001
# politics: conservative = 0, moderate = 1, liberal = 2

# Anaconda3-2020.02  Python 3.7.6  scikit 0.22.1
# Windows 10/11

import numpy as np 
from sklearn import tree 

# ---------------------------------------------------------

def tree_to_pseudo(model, feature_names):
  # custom function to display tree model pseudo-code
  left = model.tree_.children_left
  right = model.tree_.children_right
  threshold = model.tree_.threshold
  features = [feature_names[i] for i in model.tree_.feature]
  value = model.tree_.value

  def recurse(left, right, threshold, features, node, depth=0):
    indent = "  " * depth
    if (threshold[node] != -2):  # -2 means node is not a leaf
      v = "%0.4f" % threshold[node]
      print(indent,"if ( " + features[node] + " lte " +
        str(v) + " ) {")

      if left[node] != -1:  # -1 means no child node
        recurse(left, right, threshold, features, \
          left[node], depth+1)
        print(indent,"} else {")
        if right[node] != -1:
          recurse(left, right, threshold, features, \
            right[node], depth+1)
        print(indent,"}")
    else:
      idx = np.argmax(value[node])
      # print(indent,"return " + str(value[node]))
      print(indent,"return " + str(model.classes_[idx]))

  recurse(left, right, threshold, features, 0)

# ---------------------------------------------------------

def show_confusion(cm):
  dim = len(cm)
  mx = np.max(cm)             # largest count in cm
  wid = len(str(mx)) + 1      # width to print
  fmt = "%" + str(wid) + "d"  # like "%3d"
  for i in range(dim):
    print("actual   ", end="")
    print("%3d:" % i, end="")
    for j in range(dim):
      print(fmt % cm[i][j], end="")
    print("")
  print("------------")
  print("predicted    ", end="")
  for j in range(dim):
    print(fmt % j, end="")
  print("")

# ---------------------------------------------------------

def main():
  # 0. get ready
  print("\nBegin scikit decision tree example ")
  print("Predict politics from sex, age, State, income ")
  np.random.seed(0)
  np.set_printoptions(precision=4, suppress=True)

  # sex  age    state    income   politics
  #  0   0.27   0  1  0   0.7610   2
  #  1   0.19   0  0  1   0.6550   0

  # 1. load data
  print("\nLoading data into memory ")
  train_file = ".\\Data\\people_train.txt"
  train_xy = np.loadtxt(train_file, usecols=range(0,7),
    delimiter="\t", comments="#",  dtype=np.float32) 
  train_x = train_xy[:,0:6]
  train_y = train_xy[:,6].astype(int)

  test_file = ".\\Data\\people_test.txt"
  test_xy = np.loadtxt(test_file, usecols=range(0,7),
    delimiter="\t", comments="#",  dtype=np.float32) 
  test_x = test_xy[:,0:6]
  test_y = test_xy[:,6].astype(int)

  print("\nTraining data:")
  print(train_x[0:4])
  print(". . . \n")
  print(train_y[0:4])
  print(". . . ")

  # 2. create and train 
  md = 3
  print("\nCreating decision tree max_depth=" + str(md))
  model = tree.DecisionTreeClassifier(max_depth=md) 
  model.fit(train_x, train_y)
  print("Done ")

  # 3. evaluate
  acc_train = model.score(train_x, train_y)
  print("\nAccuracy on train = %0.4f " % acc_train)
  acc_test = model.score(test_x, test_y)
  print("Accuracy on test = %0.4f " % acc_test)

  # 3b. display formatted confusion matrix
  from sklearn.metrics import confusion_matrix
  y_predicteds = model.predict(test_x)
  cm = confusion_matrix(test_y, y_predicteds)
  print("\nConfusion matrix: \n")
  show_confusion(cm)

  # 4a. visualize using custom function
  # print("\nModel in pseudo-code: ")
  # tree_to_pseudo(model, ["sex", "age",
  #   "state0", "state1", "state2",
  #  "income"])

  # 4b. use built-in export_text()
  pseudo = tree.export_text(model,
    feature_names=["sex", "age",
    "state0", "state1", "state2",
    "income"])
  print("\nModel in pseudo-code: ")
  print(pseudo)

  # 4c. use built-in plot_tree()
  import matplotlib.pyplot as plt
  plt.figure(figsize=(14,8),
    tight_layout=True)  # w,h inches
  tree.plot_tree(model,
    feature_names=["sex", "age",
      "state0", "state1", "state2",
      "income"],
    class_names=["con", "mod", "lib"],
    fontsize=8)
  plt.show()

  # 5. use model
  print("\nPredict for: M 35 Nebraska $55K ")
  X = np.array([[0, 0.35, 0,1,0, 0.5500]],
    dtype=np.float32)
  probs = model.predict_proba(X)
  print("\nPrediction pseudo-probs: ")
  print(probs)

  politic = model.predict(X)
  print("\nPredicted class: ")
  print(politic)

  # 6. TODO: save model using pickle
  # import pickle
  # print("Saving trained tree model ")
  # path = ".\\Models\\tree_scikit_model.sav"
  # pickle.dump(model, open(path, "wb"))

  # use saved model
  # X = np.array([[0, 0.35, 0,1,0, 0.5500]],
  #   dtype=np.float32)
  # with open(path, 'rb') as f:
  #   loaded_model = pickle.load(f)
  # pa = loaded_model.predict_proba(X)
  # print(pa)
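
  # alternative save/load using joblib (a sketch -- joblib is
  # installed along with scikit; the file path is arbitrary)
  # import joblib
  # joblib.dump(model, ".\\Models\\tree_scikit_model.joblib")
  # loaded_model = joblib.load(".\\Models\\tree_scikit_model.joblib")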

  print("\nEnd scikit decision tree demo ")

if __name__ == "__main__":
  main()

Training data. Replace commas with tab characters, or modify the program.

# people_train.txt
# sex (M=0, F=1), age (div 100)
# state (michigan = 100, nebraska = 010,
#  oklahoma = 001)
# income (div 100,000)
# politics (con = 0, mod = 1, lib = 2)
#
1,0.24,1,0,0,0.2950,2
0,0.39,0,0,1,0.5120,1
1,0.63,0,1,0,0.7580,0
0,0.36,1,0,0,0.4450,1
1,0.27,0,1,0,0.2860,2
1,0.50,0,1,0,0.5650,1
1,0.50,0,0,1,0.5500,1
0,0.19,0,0,1,0.3270,0
1,0.22,0,1,0,0.2770,1
0,0.39,0,0,1,0.4710,2
1,0.34,1,0,0,0.3940,1
0,0.22,1,0,0,0.3350,0
1,0.35,0,0,1,0.3520,2
0,0.33,0,1,0,0.4640,1
1,0.45,0,1,0,0.5410,1
1,0.42,0,1,0,0.5070,1
0,0.33,0,1,0,0.4680,1
1,0.25,0,0,1,0.3000,1
0,0.31,0,1,0,0.4640,0
1,0.27,1,0,0,0.3250,2
1,0.48,1,0,0,0.5400,1
0,0.64,0,1,0,0.7130,2
1,0.61,0,1,0,0.7240,0
1,0.54,0,0,1,0.6100,0
1,0.29,1,0,0,0.3630,0
1,0.50,0,0,1,0.5500,1
1,0.55,0,0,1,0.6250,0
1,0.40,1,0,0,0.5240,0
1,0.22,1,0,0,0.2360,2
1,0.68,0,1,0,0.7840,0
0,0.60,1,0,0,0.7170,2
0,0.34,0,0,1,0.4650,1
0,0.25,0,0,1,0.3710,0
0,0.31,0,1,0,0.4890,1
1,0.43,0,0,1,0.4800,1
1,0.58,0,1,0,0.6540,2
0,0.55,0,1,0,0.6070,2
0,0.43,0,1,0,0.5110,1
0,0.43,0,0,1,0.5320,1
0,0.21,1,0,0,0.3720,0
1,0.55,0,0,1,0.6460,0
1,0.64,0,1,0,0.7480,0
0,0.41,1,0,0,0.5880,1
1,0.64,0,0,1,0.7270,0
0,0.56,0,0,1,0.6660,2
1,0.31,0,0,1,0.3600,1
0,0.65,0,0,1,0.7010,2
1,0.55,0,0,1,0.6430,0
0,0.25,1,0,0,0.4030,0
1,0.46,0,0,1,0.5100,1
0,0.36,1,0,0,0.5350,0
1,0.52,0,1,0,0.5810,1
1,0.61,0,0,1,0.6790,0
1,0.57,0,0,1,0.6570,0
0,0.46,0,1,0,0.5260,1
0,0.62,1,0,0,0.6680,2
1,0.55,0,0,1,0.6270,0
0,0.22,0,0,1,0.2770,1
0,0.50,1,0,0,0.6290,0
0,0.32,0,1,0,0.4180,1
0,0.21,0,0,1,0.3560,0
1,0.44,0,1,0,0.5200,1
1,0.46,0,1,0,0.5170,1
1,0.62,0,1,0,0.6970,0
1,0.57,0,1,0,0.6640,0
0,0.67,0,0,1,0.7580,2
1,0.29,1,0,0,0.3430,2
1,0.53,1,0,0,0.6010,0
0,0.44,1,0,0,0.5480,1
1,0.46,0,1,0,0.5230,1
0,0.20,0,1,0,0.3010,1
0,0.38,1,0,0,0.5350,1
1,0.50,0,1,0,0.5860,1
1,0.33,0,1,0,0.4250,1
0,0.33,0,1,0,0.3930,1
1,0.26,0,1,0,0.4040,0
1,0.58,1,0,0,0.7070,0
1,0.43,0,0,1,0.4800,1
0,0.46,1,0,0,0.6440,0
1,0.60,1,0,0,0.7170,0
0,0.42,1,0,0,0.4890,1
0,0.56,0,0,1,0.5640,2
0,0.62,0,1,0,0.6630,2
0,0.50,1,0,0,0.6480,1
1,0.47,0,0,1,0.5200,1
0,0.67,0,1,0,0.8040,2
0,0.40,0,0,1,0.5040,1
1,0.42,0,1,0,0.4840,1
1,0.64,1,0,0,0.7200,0
0,0.47,1,0,0,0.5870,2
1,0.45,0,1,0,0.5280,1
0,0.25,0,0,1,0.4090,0
1,0.38,1,0,0,0.4840,0
1,0.55,0,0,1,0.6000,1
0,0.44,1,0,0,0.6060,1
1,0.33,1,0,0,0.4100,1
1,0.34,0,0,1,0.3900,1
1,0.27,0,1,0,0.3370,2
1,0.32,0,1,0,0.4070,1
1,0.42,0,0,1,0.4700,1
0,0.24,0,0,1,0.4030,0
1,0.42,0,1,0,0.5030,1
1,0.25,0,0,1,0.2800,2
1,0.51,0,1,0,0.5800,1
0,0.55,0,1,0,0.6350,2
1,0.44,1,0,0,0.4780,2
0,0.18,1,0,0,0.3980,0
0,0.67,0,1,0,0.7160,2
1,0.45,0,0,1,0.5000,1
1,0.48,1,0,0,0.5580,1
0,0.25,0,1,0,0.3900,1
0,0.67,1,0,0,0.7830,1
1,0.37,0,0,1,0.4200,1
0,0.32,1,0,0,0.4270,1
1,0.48,1,0,0,0.5700,1
0,0.66,0,0,1,0.7500,2
1,0.61,1,0,0,0.7000,0
0,0.58,0,0,1,0.6890,1
1,0.19,1,0,0,0.2400,2
1,0.38,0,0,1,0.4300,1
0,0.27,1,0,0,0.3640,1
1,0.42,1,0,0,0.4800,1
1,0.60,1,0,0,0.7130,0
0,0.27,0,0,1,0.3480,0
1,0.29,0,1,0,0.3710,0
0,0.43,1,0,0,0.5670,1
1,0.48,1,0,0,0.5670,1
1,0.27,0,0,1,0.2940,2
0,0.44,1,0,0,0.5520,0
1,0.23,0,1,0,0.2630,2
0,0.36,0,1,0,0.5300,2
1,0.64,0,0,1,0.7250,0
1,0.29,0,0,1,0.3000,2
0,0.33,1,0,0,0.4930,1
0,0.66,0,1,0,0.7500,2
0,0.21,0,0,1,0.3430,0
1,0.27,1,0,0,0.3270,2
1,0.29,1,0,0,0.3180,2
0,0.31,1,0,0,0.4860,1
1,0.36,0,0,1,0.4100,1
1,0.49,0,1,0,0.5570,1
0,0.28,1,0,0,0.3840,0
0,0.43,0,0,1,0.5660,1
0,0.46,0,1,0,0.5880,1
1,0.57,1,0,0,0.6980,0
0,0.52,0,0,1,0.5940,1
0,0.31,0,0,1,0.4350,1
0,0.55,1,0,0,0.6200,2
1,0.50,1,0,0,0.5640,1
1,0.48,0,1,0,0.5590,1
0,0.22,0,0,1,0.3450,0
1,0.59,0,0,1,0.6670,0
1,0.34,1,0,0,0.4280,2
0,0.64,1,0,0,0.7720,2
1,0.29,0,0,1,0.3350,2
0,0.34,0,1,0,0.4320,1
0,0.61,1,0,0,0.7500,2
1,0.64,0,0,1,0.7110,0
0,0.29,1,0,0,0.4130,0
1,0.63,0,1,0,0.7060,0
0,0.29,0,1,0,0.4000,0
0,0.51,1,0,0,0.6270,1
0,0.24,0,0,1,0.3770,0
1,0.48,0,1,0,0.5750,1
1,0.18,1,0,0,0.2740,0
1,0.18,1,0,0,0.2030,2
1,0.33,0,1,0,0.3820,2
0,0.20,0,0,1,0.3480,0
1,0.29,0,0,1,0.3300,2
0,0.44,0,0,1,0.6300,0
0,0.65,0,0,1,0.8180,0
0,0.56,1,0,0,0.6370,2
0,0.52,0,0,1,0.5840,1
0,0.29,0,1,0,0.4860,0
0,0.47,0,1,0,0.5890,1
1,0.68,1,0,0,0.7260,2
1,0.31,0,0,1,0.3600,1
1,0.61,0,1,0,0.6250,2
1,0.19,0,1,0,0.2150,2
1,0.38,0,0,1,0.4300,1
0,0.26,1,0,0,0.4230,0
1,0.61,0,1,0,0.6740,0
1,0.40,1,0,0,0.4650,1
0,0.49,1,0,0,0.6520,1
1,0.56,1,0,0,0.6750,0
0,0.48,0,1,0,0.6600,1
1,0.52,1,0,0,0.5630,2
0,0.18,1,0,0,0.2980,0
0,0.56,0,0,1,0.5930,2
0,0.52,0,1,0,0.6440,1
0,0.18,0,1,0,0.2860,1
0,0.58,1,0,0,0.6620,2
0,0.39,0,1,0,0.5510,1
0,0.46,1,0,0,0.6290,1
0,0.40,0,1,0,0.4620,1
0,0.60,1,0,0,0.7270,2
1,0.36,0,1,0,0.4070,2
1,0.44,1,0,0,0.5230,1
1,0.28,1,0,0,0.3130,2
1,0.54,0,0,1,0.6260,0

Test data. Replace commas with tab characters, or modify the program.

0,0.51,1,0,0,0.6120,1
0,0.32,0,1,0,0.4610,1
1,0.55,1,0,0,0.6270,0
1,0.25,0,0,1,0.2620,2
1,0.33,0,0,1,0.3730,2
0,0.29,0,1,0,0.4620,0
1,0.65,1,0,0,0.7270,0
0,0.43,0,1,0,0.5140,1
0,0.54,0,1,0,0.6480,2
1,0.61,0,1,0,0.7270,0
1,0.52,0,1,0,0.6360,0
1,0.30,0,1,0,0.3350,2
1,0.29,1,0,0,0.3140,2
0,0.47,0,0,1,0.5940,1
1,0.39,0,1,0,0.4780,1
1,0.47,0,0,1,0.5200,1
0,0.49,1,0,0,0.5860,1
0,0.63,0,0,1,0.6740,2
0,0.30,1,0,0,0.3920,0
0,0.61,0,0,1,0.6960,2
0,0.47,0,0,1,0.5870,1
1,0.30,0,0,1,0.3450,2
0,0.51,0,0,1,0.5800,1
0,0.24,1,0,0,0.3880,1
0,0.49,1,0,0,0.6450,1
1,0.66,0,0,1,0.7450,0
0,0.65,1,0,0,0.7690,0
0,0.46,0,1,0,0.5800,0
0,0.45,0,0,1,0.5180,1
0,0.47,1,0,0,0.6360,0
0,0.29,1,0,0,0.4480,0
0,0.57,0,0,1,0.6930,2
0,0.20,1,0,0,0.2870,2
0,0.35,1,0,0,0.4340,1
0,0.61,0,0,1,0.6700,2
0,0.31,0,0,1,0.3730,1
1,0.18,1,0,0,0.2080,2
1,0.26,0,0,1,0.2920,2
0,0.28,1,0,0,0.3640,2
0,0.59,0,0,1,0.6940,2